author:
- Yu Sun
- Xinhao Li
- Karan Dalal
- Jiarui Xu
- Arjun Vikram
- Genghan Zhang
- Yann Dubois
- Xinlei Chen
- Xiaolong Wang
- Sanmi Koyejo
- Tatsunori Hashimoto
- Carlos Guestrin
submission:
year: "2024"
file:
- "[[01. Learning to (Learn at Test Time) - RNNs with Expressive Hidden States.pdf|05. Learning to (Learn at Test Time) - RNNs with Expressive Hidden States]]"
related:
tags:
- Test-Time-Learning
- Test-Time-Training
review date: 2024-11-12
Summary
Test-Time Training (TTT) 레이어는 선형 복잡도를 가지면서도 긴 컨텍스트에서 효과적인 성능을 보이는 새로운 시퀀스 모델링 레이어를 제안 [Abstract, p.1]
Self-attention은 긴 컨텍스트에서 좋은 성능을 보이지만 이차 복잡도를 가지고, 기존 RNN은 선형 복잡도를 가지지만 긴 컨텍스트에서 성능이 제한적 [Section 1, p.2]
TTT 레이어는 히든 스테이트를 머신러닝 모델로 구성하고 업데이트 규칙을 자기지도학습 단계로 설정하여 선형 복잡도와 표현력 있는 히든 스테이트를 동시에 달성 [Section 1, p.2-3]
TTT-Linear와 TTT-MLP 두 가지 구현을 제안했으며, Transformer와 Mamba와 비교했을 때 125M에서 1.3B 파라미터 규모에서 동등하거나 더 좋은 성능을 보임 [Abstract, p.1]
Self-attention과 기존 RNN의 장단점을 분석하고 긴 컨텍스트에서 효과적인 새로운 아키텍처의 필요성을 제기 [Section 1, p.2]
히든 스테이트를 머신러닝 모델로 구성하고 업데이트를 자기지도학습으로 하는 Test-Time Training (TTT) 레이어를 제안 [Section 2.1, p.3-4]
Mini-batch TTT와 dual form을 통해 하드웨어 효율성을 개선 [Section 2.4-2.5, p.7-8]
TTT-Linear와 TTT-MLP 두 가지 구현을 제안하고 각각의 특성을 분석 [Section 2.6-2.7, p.10-12]
실험을 통해 제안한 방법이 긴 컨텍스트에서 기존 방법들보다 우수한 성능을 보임을 일부 입증 [Section 3, p.13-17]
기존 RNN은 선형 복잡도를 가지지만 히든 스테이트의 표현력 한계로 긴 컨텍스트에서 성능이 제한적이다. ["Self-attention can also be viewed from the perspective above, except that its hidden state, commonly known as the Key-Value (KV) cache, is a list that grows linearly with t.", Section 1, p.2]
Self-attention은 긴 컨텍스트에서 좋은 성능을 보이지만 이차 복잡도를 가진다. ["Unlike self-attention, RNN layers have to compress context into a hidden state of fixed size.", Section 1, p.2]
선형 복잡도를 유지하면서 표현력 있는 히든 스테이트를 가진 새로운 시퀀스 모델링 레이어가 필요 ["To remain both efficient and expressive in long context, we need a better compression heuristic.", Section 1, p.2]
RNN 레이어들은 고정된 크기의 히든 스테이트에 컨텍스트를 압축해야 함 ["All sequence modeling layers can be viewed from the perspective of storing historic context into a hidden state", Section 2, p.3]
Self-attention은 Key-Value 캐시를 통해 모든 히든 컨텍스트를 명시적으로 저장 ["The hidden state explicitly stores all historic context without compression, making self-attention more expressive than RNN layers for long context.", Section 2, p.4]
자기지도학습은 대규모 훈련 데이터를 모델 가중치로 효과적으로 압축 가능 ["The process of parametric learning can be viewed as compressing a massive training set into the weights of a model.", Section 2.1, p.4]
기존 RNN의 히든 스테이트 표현력 한계 극복 [Section 1, p.2]
Self-attention의 이차 복잡도 문제 해결 [Section 1, p.2]
효율적인 하드웨어 활용을 위한 최적화 필요 [Section 2.5, p.7-8]
히든 스테이트를 머신러닝 모델로 구성 ["Our key idea is to make the hidden state itself a model f with weights W", Section 2.1, p.4]
TTT Layer는 히든 스테이트를 머신러닝 모델(f)로 구성하고, 파라미터(W)를 통해 컨텍스트 정보를 저장함 ["The hidden state st is now equivalent to Wt, the weights of a model f", Section 2.1, p.4]
모든 시퀀스 데이터에 대해 테스트 타임에 학습이 이뤄짐으로써 각 시퀀스에 맞는 최적의 파라미터를 찾음 ["Even at test time, our new layer still trains a different sequence of weights W1,...,WT for every input sequence", Section 2.1, p.5]
Inner loop(TTT)와 Outer loop(네트워크 학습)의 이중 학습 구조를 가짐 ["We refer to training the larger network as the outer loop, and training W within each TTT layer as the inner loop", Section 2.2, p.6]
업데이트 규칙을 자기지도학습 단계로 설정 ["and the update rule a step of self-supervised learning", Section 2.1, p.4]
- 입력 토큰 xt를 corrupted input x̃t로 변환하고 이를 다시 복원하는 task를 학습 ["One choice of ℓ is reconstructing xt itself. To make the learning problem nontrivial, we first process xt into a corrupted input x̃t", Section 2.3, p.6-7]
- θK, θV, θQ 세 개의 학습 가능한 projection matrices를 도입:
- Input Token xt가 들어오면 3가지 뷰로 투영되는데, 이 중 Model Weights (Wt)는 Training view와 Label view를 사용해 self-supervised learning으로 업데이트 / 최종 Output (zt) 생성은 업데이트된 weight로 Test View를 처리하여 진행
- <svg viewBox="0 0 800 500" xmlns="http://www.w3.org/2000/svg">
출력함수 f에 대해 TTT-Linear와 TTT-MLP 두 가지 구현 제안 ["We propose two variants of TTT layers – TTT-Linear and TTT-MLP", Section 2.7, p.12]
TTT-Linear: 히든 스테이트가 선형 모델(f(x) = Wx)로 구성됨
TTT-MLP: 히든 스테이트가 2층 MLP로 구성됨, 더 표현력이 높음 (2-layer MLP with GELU activation)
둘 다 Layer Normalization과 residual connection을 포함 ["For TTT-Linear, flin(x) = Wx, where W is square. For TTT-MLP, fMLP has two layers similar to the MLPs in Transformers", Section 2.7, p.12]
Mini-batch TTT를 통한 병렬화 ["Our proposed solution – mini-batch gradient descent", Section 2.4, p.7]
Dual form을 통한 하드웨어 효율성 향상 ["We call this procedure the dual form, in contrast to the primal form", Section 2.5, p.7-8]
메모리 효율적인 그래디언트 체크포인팅 ["A standard technique to save memory in this scenario is gradient checkpointing", Appendix C, p.31]
시간 축으로의 체크포인팅:
메모리-계산 트레이드오프:
메모리 효율성:
유연성:
*** 특히 긴 시퀀스를 처리할 때, 일반적인 layer 방향의 체크포인팅과 달리 시간 축으로 적용되어 TTT의 특성을 잘 활용할 수 있게 됨